import torch 
import tensorly as tly

import collections

# Compatibility shim: Python 3.10 removed the deprecated aliases
# collections.Iterable / Mapping / MutableSet / MutableMapping (they live only
# in collections.abc now). They are re-added here, before `import tltorch`
# below, presumably because tltorch or one of its dependencies still uses the
# old names -- TODO confirm and drop once upstream is fixed.
# `import collections` alone does not guarantee the `abc` submodule is bound
# as an attribute, so import it explicitly.
import collections.abc

for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name

import tltorch
import math
import collections
from .matrix_conv import Conv2d_USV

# collections.Iterable = collections.abc.Iterable
# collections.Mapping = collections.abc.Mapping
# collections.MutableSet = collections.abc.MutableSet
# collections.MutableMapping = collections.abc.MutableMapping


class Lenet5(torch.nn.Module):
    """LeNet-5 style classifier whose two convolutions are low-rank.

    Layout of ``self.lr_model`` (a ``torch.nn.Sequential``):
    conv(1->20, 5x5) -> ReLU -> maxpool(2) -> conv(20->50, 5x5) -> ReLU ->
    maxpool(2) -> flatten -> linear(800->500) -> ReLU -> linear(500->10).
    Only the two convolutions are compressed; the linear head is dense.

    The 800 input features of the first linear layer equal 50 * 4 * 4, which
    is what a 1x28x28 input yields after the two unpadded 5x5 convolutions
    and 2x2 poolings. Inputs are therefore assumed to be single-channel 28x28
    images; other resolutions fail at the first linear layer.

    NOTE(review): ``dlr_opt`` needs a ``layer`` attribute and
    ``update_step`` / ``populate_gradients`` methods. This class defines none
    of them, so as written it is meant for ordinary torch optimizers, not
    ``dlr_opt``.
    """

    # Values accepted for ``args.deco``.
    _DECOMPOSITIONS = ('cp', 'tucker', 'mat')

    def __init__(self, args):
        """Build the network.

        Args:
            args: configuration object read via attributes:
                ``deco`` -- ``'cp'`` or ``'tucker'`` to factorize each
                convolution with tltorch using the float rank ``1 - tau``
                (interpreted by tltorch, presumably as a parameter fraction;
                see the tltorch docs), or ``'mat'`` to use ``Conv2d_USV``
                with integer rank ``ceil((1 - tau) * out_channels)``.
                ``tau`` -- compression amount; ``1 - tau`` sets the rank.
                The object is also stored as ``self.args``.

        Raises:
            ValueError: if ``args.deco`` is not one of ``_DECOMPOSITIONS``.
        """
        super().__init__()
        self.args = args
        # Fail fast: previously an unknown value left ``lr_model`` unset and
        # only surfaced later as an AttributeError inside ``forward``.
        if args.deco not in self._DECOMPOSITIONS:
            raise ValueError(
                f"unsupported decomposition {args.deco!r}; "
                f"expected one of {self._DECOMPOSITIONS}"
            )
        # Module order (and hence state_dict keys ``lr_model.<i>.*``) and the
        # order in which layers are constructed match the original layout.
        self.lr_model = torch.nn.Sequential(
            self._low_rank_conv(1, 20, args.deco, args.tau),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            self._low_rank_conv(20, 50, args.deco, args.tau),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(800, out_features=500),
            torch.nn.ReLU(),
            torch.nn.Linear(500, out_features=10),
        )

    @staticmethod
    def _low_rank_conv(in_channels, out_channels, deco, tau):
        """Build one 5x5, stride-1 low-rank convolution for decomposition ``deco``."""
        if deco == 'mat':
            # Integer rank: keep ceil((1 - tau) * out_channels) components.
            rank = math.ceil((1 - tau) * out_channels)
            return Conv2d_USV(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=5, stride=1, rank=rank)
        # 'cp' / 'tucker': start from a freshly initialised dense convolution
        # and decompose its weights with tltorch.
        dense = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=5, stride=1)
        return tltorch.FactorizedConv.from_conv(
            dense, rank=1 - tau, decompose_weights=True, factorization=deco)

    def forward(self, x):
        """Return class logits, one row of 10 per input.

        Assumes ``x`` is (batch, 1, 28, 28) -- see the class docstring.
        """
        return self.lr_model(x)
    

